Revisiting Deep Learning Models for Tabular Data - 论文阅读
Date:
论文: 重新审视表格数据的深度学习模型 论文地址:https://arxiv.org/pdf/2106.11959
摘要
现有的关于表格数据深度学习的文献提出了各种新颖的架构,并在多个数据集上报告了竞争性的结果。然而,这些模型通常没有被适当地相互比较,并且现有的研究通常使用不同的基准和实验协议。因此,对于研究人员和实践者来说,哪个模型表现最佳尚不清楚。此外,该领域仍然缺乏有效的基线,即能够在不同问题上提供竞争性能的易于使用的模型。在这项工作中,我们对表格数据的主要深度学习架构进行了概述,并通过确定两个简单而强大的深度架构来提高表格数据深度学习的基线。第一个架构是类似ResNet的架构,它被证明是一个强大的基线,通常在先前的工作中缺失。第二个模型是我们对Transformer架构的简单适应,它在大多数任务上优于其他解决方案。这两个模型在相同的训练和调优协议下被与许多现有的架构进行了比较。我们还将最佳的深度学习模型与梯度提升决策树(GBDT)进行了比较,得出结论:目前还没有普遍优越的解决方案。
第一章:引言
由于深度学习在图像、音频和文本等数据领域取得了巨大的成功,研究者们对将这种成功扩展到表格数据问题的兴趣浓厚。在这些问题中,数据点被表示为异构特征的向量,这在工业应用和机器学习竞赛中很常见,其中神经网络面临着来自梯度提升决策树(GBDT)的强大竞争者。除了潜在的更高性能外,使用深度学习处理表格数据还很有吸引力,因为它可以构建多模态管道,其中只有一部分输入是表格数据,而其他部分包括图像、音频等适合深度学习的数据。这样,这些管道可以通过梯度优化对所有模态进行端到端的训练。因此,最近提出了大量的深度学习解决方案,并且新的模型不断出现。然而,由于缺乏既定的基准(如计算机视觉中的ImageNet或自然语言处理中的GLUE),现有的论文使用了不同的数据集进行评估,且提出的深度学习模型往往没有得到充分的比较。因此,从现有的文献来看,尚不清楚哪种深度学习模型总体上表现优于其他模型,GBDT是否被深度学习模型超越。此外,尽管有大量的新架构,该领域仍然缺乏简单可靠的解决方案,能够在各种问题上实现竞争性表现。
第二章:背景和相关工作
FTTransformer原理介绍
FTTransformer是一个可以用于结构化(tabular)数据的分类和回归任务的模型。</br> FT 即 Feature Tokenizer的意思,把结构化数据中的离散特征和连续特征都像单词一样编码成一个向量。</br> 从而可以像对text数据那样 应用 Transformer对 Tabular数据进行特征抽取。</br> 值得注意的是,它对Transformer作了一些微妙的改动以适应 Tabular数据。</br> 例如:去除第一个Transformer输入的LayerNorm层,仿照BERT的设计增加了output token(CLS token) 与features token 一起进行进入Transformer参与注意力计算。
2.1 背景 </br> 本文的背景部分介绍了处理表格数据的挑战和表格数据的特点。表格数据是指具有行和列结构的数据,每列代表一个特征,每行代表一个实例。与图像、文本等其他类型的数据相比,表格数据的特征之间往往缺乏显式的空间或顺序关系,这使得在这种数据上应用深度学习模型存在挑战。
2.2 相关工作 </br> 在这一部分,作者回顾了已有的处理表格数据的方法,包括传统的机器学习方法(如决策树、随机森林、梯度提升机等)和基于神经网络的方法。特别地,作者讨论了Transformer模型的应用,指出传统Transformer模型在表格数据任务上的应用有限,并介绍了如何通过改进模型结构来增强其在表格数据上的表现。
第三章:FT-Transformer 模型
3.1 模型架构 FT-Transformer 是一种专门为表格数据设计的基于Transformer的模型。模型架构的核心是多头自注意力机制,旨在捕捉特征之间的复杂交互。作者详细描述了如何处理数值特征和类别特征,以及如何利用注意力机制来建模特征之间的关系。
3.2 模型优化 在这一部分,作者讨论了针对表格数据的特殊优化措施,包括特征嵌入方法、正则化技术和优化算法的选择。这些优化措施旨在提高模型的泛化能力和训练效率。
3.3 实验结果 作者展示了FT-Transformer在多个公开数据集上的实验结果,并与传统的机器学习方法(如LightGBM)进行了比较。结果显示,FT-Transformer在处理表格数据任务时能够取得有竞争力的性能,尤其是在特征之间存在复杂交互的情况下。
特征嵌入(Feature Embedding)
</br> FT-Transformer首先对表格数据中的每个特征进行嵌入。对于数值型特征,可以使用线性变换来生成嵌入;对于类别型特征,可以使用查找表(embedding lookup)生成嵌入。所有特征嵌入的维度是相同的,这样可以统一输入到Transformer层中。
Transformer层 </br>
在嵌入生成后,FT-Transformer将这些嵌入作为输入,通过一系列的Transformer层进行处理。每个Transformer层包括以下几个部分:
- 多头自注意力机制(Multi-head Self-attention):这种机制可以让模型关注到不同特征之间的相互关系。
- 前馈神经网络(Feedforward Neural Network):用于进一步处理注意力机制的输出。
- 层规范化(Layer Normalization)和残差连接(Residual Connections):这些技术帮助训练更深的网络,同时缓解梯度消失的问题。
输出层 </br>
经过多个Transformer层的处理后,FT-Transformer将特征嵌入合并(例如通过全连接层)来生成最终的输出,用于分类或回归任务。
设计特点 </br>
- 特征层面的操作:FT-Transformer在每个Transformer层对特征进行操作,这与传统的对单个数据点进行处理的方式有所不同。
- 模块化设计:模型的结构借鉴了标准Transformer的模块化设计,使其能够处理表格数据中特定的特征类型和任务。
通过这种设计,FT-Transformer能够有效地捕捉表格数据中特征之间的复杂关系,并在多种任务上展示出优异的性能。
第四章:实验
4.1 数据集和任务
实验中使用了多个不同领域的数据集,包括金融、医疗、广告等领域,旨在评估FT-Transformer在多种情况下的表现。选用的数据集覆盖了不同类型和规模,以确保结果的通用性和可靠性。
4.2 模型设置
对于FT-Transformer的设置,研究者们详细调整了各个超参数以适应不同的数据集和任务。这些设置包括层数、隐藏单元数、头数、学习率等。为了保证公平性,所有模型在相同的硬件和计算资源下进行训练。
4.3 性能对比
在性能对比中,FT-Transformer展示了优越的效果,特别是在处理高维度和大规模数据时。与传统的梯度提升树(如LightGBM、CatBoost等)相比,FT-Transformer在多个任务上均取得了更好的表现,尤其是在处理具有复杂交互关系的数据时。
4.4 结果分析
实验结果表明,FT-Transformer不仅在性能上优于现有的深度学习模型和树模型,而且在处理非结构化数据和特征交互方面也具有显著优势。这验证了Transformer架构在表格数据处理中的潜力和适用性。
4.5 讨论
研究者们讨论了FT-Transformer在不同任务中的表现差异,并分析了可能的原因。尽管FT-Transformer在大多数情况下表现优异,但在某些特定的数据集上,传统模型仍然具有优势。这部分讨论为未来的改进方向提供了参考。
第五章:结论与未来工作
5.1 总结
本研究提出了FT-Transformer,一种适用于表格数据的新型深度学习模型。通过综合多个不同领域的数据集,FT-Transformer展示了其在处理复杂特征交互和高维度数据时的优势。与现有的梯度提升树模型(如LightGBM、CatBoost)相比,FT-Transformer在多个任务上的表现显著更优,表明Transformer架构在表格数据处理中具有巨大的潜力。
5.2 未来工作
尽管FT-Transformer在多个实验中展示了其优越性,但仍有改进空间。未来的研究可能集中在以下几个方面:
- 模型优化:进一步优化模型的架构和超参数,以提升其计算效率和性能。
- 领域适应性:探索FT-Transformer在不同领域中的应用潜力,特别是在需要处理异构数据的场景中。
- 可解释性:增强模型的可解释性,使得其在实际应用中更易于理解和信任。
- 集成学习:将FT-Transformer与其他机器学习模型相结合,以构建更强大的混合模型。
总体而言,FT-Transformer为表格数据的深度学习研究提供了新的方向,并有望在未来的应用中发挥更大的作用。
论文延伸
论文对FT-Transformer和LightGBM进行了比较,以评估两者在处理表格数据任务上的性能。以下是主要比较点:
- 模型结构:
- FT-Transformer:FT-Transformer是基于Transformer的深度学习模型,主要依赖于多头自注意力机制来捕捉数据特征之间的复杂关系。这使得它在处理异构特征和多模态数据时具有优势。
- LightGBM:LightGBM是基于决策树的梯度提升框架,以快速、高效、精确的模型为特点,特别适用于处理大规模数据集和具有高维度特征的数据。
- 性能:
- FT-Transformer:在许多表格数据任务中表现优越,尤其是在需要处理复杂的特征交互的场景中。它的性能受益于Transformer层能够有效地捕捉特征之间的长程依赖关系。
- LightGBM:通常在结构化数据集上表现非常出色,尤其是在传统机器学习任务中,如回归和分类任务。它的优势在于处理数值型和类别型特征的能力,以及对缺失值的鲁棒性。
- 应用场景:
- FT-Transformer:更适合于需要处理复杂特征交互、多模态输入或者异构特征的任务。它也适用于任务中要求端到端训练的情况,比如结合图像、文本或时间序列数据的场景。
- LightGBM:广泛用于实际业务中,特别是在需要高效处理大规模数据的应用场景,如金融、电子商务和推荐系统等。它的简单性和速度使其成为许多机器学习任务中的首选。
- 模型解释性:
- FT-Transformer:作为深度学习模型,通常被视为“黑盒子”,即很难直接解释其内部工作原理。然而,随着自注意力机制和解释方法的发展,模型的部分可解释性正在逐步提高。
- LightGBM:更容易解释,因为它基于决策树,可以通过特征重要性评分等工具了解模型的决策过程。
总结来说,FT-Transformer和LightGBM在不同场景下各有优势。FT-Transformer在处理复杂特征交互和多模态数据方面具有潜力,而LightGBM则在大规模数据处理和业务应用中广泛使用。选择使用哪种模型取决于具体的应用场景、数据特征以及对模型性能和可解释性的要求。
FT-Transformer(Feature Tokenizer + Transformer)是一种针对表格数据的Transformer架构的改编模型。该模型通过将表格数据的特征转化为嵌入表示,然后应用一系列Transformer层来处理这些嵌入。FT-Transformer的设计灵感来源于Transformer在自然语言处理和其他领域的成功应用。以下是FT-Transformer的主要实现细节: